# -*- coding: utf-8 -*-
# @Time : 2021-11-17 18:50
# @Author : lwb
# @Site : 
# @File : CNNprocess.py
import torch
from  torch.nn.utils.rnn import pad_sequence
# 这个函数也可不用，只需要将他的操作，在dataset中return 数据的时候应该就行了
# collate_fn参数指向一个函数，用于对一个批次的样本进行整理，如将其转换成张量
def collate_fn(examples):
    inputs = [torch.tensor(ex[0]) for ex in examples]
    # 输出的目标targets为该批次中全部样例输出结果（0或者1）构成的张量
    targets = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
    # 对批次内的样本进行补齐，使其具有相同长度
    inputs = pad_sequence(inputs,batch_first=True)   # pad_sequence函数实现补齐（Padding）功能，使得一个批次中全部序列长度相同（同最大长度序列），不足的默认使用0补齐。
    return inputs, targets